__author__ = 'Qi'
# Created on 11/3/21.
import torch
import numpy as np
from torch.utils.data.sampler import Sampler
import torch.nn as nn

from libauc.losses import AUCMLoss
# from loss import logistic_loss, sigmoid_loss

class AUPRCSampler(Sampler):
    """Yield dataset indices so that every batch holds a fixed number of positives.

    Indices come out in consecutive groups of ``batchSize``: ``posNum``
    indices of label 1 (positive / minority class) followed by
    ``batchSize - posNum`` indices of label 0 (negative / majority class).
    The class that would run out first is topped up by sampling its own
    indices with replacement, so each index of the other class appears once
    per epoch. Use it with a loader whose batch size equals ``batchSize`` so
    the groups line up with batches.

    Args:
        labels: per-sample labels. Indices are grouped by ``str(label)``, so
            positives must stringify to ``'1'`` and negatives to ``'0'``
            (e.g. Python or NumPy integers; floats such as ``1.0`` do not match).
        batchSize: number of indices per group; must exceed ``posNum``.
        posNum: number of positive indices at the start of each group.

    Raises:
        ValueError: if ``batchSize <= posNum``.
    """

    def __init__(self, labels, batchSize, posNum=1):
        # positive class: minority class
        # negative class: majority class
        if batchSize <= posNum:
            raise ValueError(f'batchSize ({batchSize}) must be larger than posNum ({posNum})')

        self.labels = labels
        self.posNum = posNum
        self.batchSize = batchSize

        # Map str(label) -> list of sample indices carrying that label.
        self.clsLabelList = np.unique(labels)
        self.dataDict = {str(label): [] for label in self.clsLabelList}
        for i in range(len(self.labels)):
            self.dataDict[str(self.labels[i])].append(i)

        # Index order of the most recent epoch (rebuilt by every __iter__ call).
        self.ret = []

    def _num_batches(self):
        """Return the number of groups per epoch, i.e. the larger class's group count."""
        num_pos = len(self.dataDict.get(str(1), ()))
        num_neg = len(self.dataDict.get(str(0), ()))
        return max(num_pos // self.posNum, num_neg // (self.batchSize - self.posNum))

    def __iter__(self):
        # Work on copies: the top-up below must not leak duplicates back into
        # the per-class pools, otherwise they would persist into later epochs.
        minority_data_list = list(self.dataDict[str(1)])
        majority_data_list = list(self.dataDict[str(0)])

        np.random.shuffle(minority_data_list)
        np.random.shuffle(majority_data_list)

        neg_per_batch = self.batchSize - self.posNum
        num_batches = self._num_batches()

        # At most one class is short; oversample it (with replacement) so both
        # classes can fill num_batches groups.
        pos_missing = num_batches * self.posNum - len(minority_data_list)
        neg_missing = num_batches * neg_per_batch - len(majority_data_list)
        if pos_missing > 0:
            minority_data_list.extend(np.random.choice(minority_data_list, pos_missing, replace=True).tolist())
        elif neg_missing > 0:
            majority_data_list.extend(np.random.choice(majority_data_list, neg_missing, replace=True).tolist())

        self.ret = []
        for i in range(num_batches):
            self.ret.extend(minority_data_list[i * self.posNum:(i + 1) * self.posNum])
            self.ret.extend(majority_data_list[i * neg_per_batch:(i + 1) * neg_per_batch])

        return iter(self.ret)

    def __len__(self):
        # Derived from the class counts so it is correct even before the first
        # __iter__ call (DataLoader may query the length up front).
        return self._num_batches() * self.batchSize

class pAUC_CVaR(nn.Module):
    """CVaR-style pairwise surrogate for one-way partial AUC.

    For each positive in the batch, the pairwise surrogate losses against the
    negatives are compared with a per-positive threshold ``lambda_pos``; only
    pairs above it contribute to the loss. Each call also moves the thresholds
    by a subgradient step of ``eta / pos_length``.

    Args:
        pos_length: size of the per-positive state ``lambda_pos``; every
            ``index_p`` value must be smaller than it.
        num_neg: used to round ``gamma`` to a multiple of ``1 / num_neg``
            (presumably the number of negatives per batch -- confirm with caller).
        threshold: margin for squared hinge loss.
        gamma: fraction of negatives kept in the tail.
        eta: step size of the ``lambda_pos`` update.
        loss_type: ``'sh'`` (squared hinge) or ``'sq'`` (square).

    Raises:
        ValueError: if ``gamma`` rounds to zero, or (in ``forward``) for an
            unknown ``loss_type``.
    """

    def __init__(self, pos_length, num_neg, threshold=1, gamma=0.2, eta=0.1, loss_type = 'sh'):
        super(pAUC_CVaR, self).__init__()
        self.gamma = round(gamma*num_neg)/num_neg
        if self.gamma <= 0:
            # A zero tail fraction would divide by zero in forward (inf/NaN loss).
            raise ValueError(f'gamma={gamma} rounds to {self.gamma} with num_neg={num_neg}; it must be positive')
        self.eta = eta
        self.num_neg = num_neg
        self.pos_length = pos_length
        # Keep state on the GPU when one is available, otherwise on the CPU.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.lambda_pos = torch.zeros(pos_length, 1, device=device)
        self.threshold = threshold
        self.loss_type = loss_type
        print('The loss type is :', self.loss_type)

    def update_smoothing(self, decay_factor):
        """Divide the threshold step size ``eta`` by ``decay_factor``."""
        self.eta = self.eta/decay_factor

    def forward(self, y_pred, y_true, index_p, index_n):
        """Return the CVaR pAUC loss of one mini-batch and update ``lambda_pos``.

        Args:
            y_pred: prediction scores, one per sample.
            y_true: labels aligned with ``y_pred`` (1 = positive, 0 = negative).
            index_p: indices into ``lambda_pos`` for the batch's positives.
            index_n: unused; kept for interface compatibility.
        """
        f_ps = y_pred[y_true == 1].view(-1, 1)      # (n_pos, 1)
        f_ns = y_pred[y_true == 0].view(-1)         # (n_neg,)
        mat_data = f_ns.repeat(len(f_ps), 1)        # (n_pos, n_neg)

        margin = self.threshold - (f_ps - mat_data)
        if self.loss_type == 'sh':
            loss = torch.max(margin, torch.zeros_like(mat_data)) ** 2
        elif self.loss_type == 'sq':
            loss = margin ** 2
        else:
            raise ValueError(f"Unsupported loss_type: {self.loss_type!r} (expected 'sh' or 'sq')")

        # Boolean mask of pairs above their positive's threshold (no gradient).
        p = loss > self.lambda_pos[index_p]

        if f_ps.size(0) == 1:
            self.lambda_pos[index_p] = self.lambda_pos[index_p]-self.eta/self.pos_length*(1 - p.sum()/(self.gamma*self.num_neg))
        else:
            self.lambda_pos[index_p] = self.lambda_pos[index_p]-self.eta/self.pos_length*(1 - p.sum(dim=1, keepdim=True)/(self.gamma*self.num_neg))

        return torch.mean(p * loss) / self.gamma

class pAUC_KL(nn.Module):
    """KL-weighted pairwise surrogate for one-way partial AUC.

    Each positive/negative pair loss ``l`` is weighted by ``exp(l / Lambda)``
    divided by a per-positive moving average ``u_pos`` of that quantity.
    Larger pair losses therefore get exponentially larger (detached) weights.

    Args:
        pos_length: size of the per-positive state ``u_pos``; every
            ``index_p`` value must be smaller than it.
        threshold: margin for squared hinge loss.
        beta: moving-average weight for ``u_pos``.
        Lambda: temperature of the exponential weighting.
        loss_type: ``'sh'`` (squared hinge) or ``'sq'`` (square).
    """

    def __init__(self, pos_length, threshold=1, beta=0.9, Lambda=1, loss_type = 'sh'):
        super(pAUC_KL, self).__init__()
        self.beta = beta
        self.Lambda = Lambda
        # Keep state on the GPU when one is available, otherwise on the CPU.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.u_pos = torch.zeros(pos_length, 1, device=device)
        self.threshold = threshold
        self.loss_type = loss_type
        print('The loss type is :', self.loss_type)

    def update_smoothing(self, decay_factor):
        """Divide the moving-average weight ``beta`` by ``decay_factor``."""
        self.beta = self.beta/decay_factor

    def forward(self, y_pred, y_true, index_p, index_n):
        """Return the KL-weighted pAUC loss of one mini-batch and update ``u_pos``.

        Args:
            y_pred: prediction scores, one per sample.
            y_true: labels aligned with ``y_pred`` (1 = positive, 0 = negative).
            index_p: indices into ``u_pos`` for the batch's positives.
            index_n: unused; kept for interface compatibility.

        Raises:
            ValueError: for an unknown ``loss_type``.
        """
        f_ps = y_pred[y_true == 1].reshape(-1, 1)   # (n_pos, 1)
        f_ns = y_pred[y_true == 0].reshape(1, -1)   # (1, n_neg)
        mat_data = f_ns.repeat(len(f_ps), 1)        # (n_pos, n_neg)

        margin = self.threshold - (f_ps - mat_data)
        if self.loss_type == 'sh':
            loss = torch.max(margin, torch.zeros_like(mat_data)) ** 2
        elif self.loss_type == 'sq':
            loss = margin ** 2
        else:
            raise ValueError(f"Unsupported loss_type: {self.loss_type!r} (expected 'sh' or 'sq')")

        exp_loss = torch.exp(loss/self.Lambda)
        # u_pos is optimizer state, not part of the model graph: update it from
        # a detached copy so autograd history does not chain across iterations.
        exp_loss_const = exp_loss.detach()

        if f_ps.size(0) == 1:
            self.u_pos[index_p] = (1 - self.beta) * self.u_pos[index_p] + self.beta * exp_loss_const.mean()
        else:
            self.u_pos[index_p] = (1 - self.beta) * self.u_pos[index_p] + self.beta * exp_loss_const.mean(1, keepdim=True)

        # Pair weights are constants; gradients flow only through `loss`.
        p = exp_loss_const / self.u_pos[index_p]

        return torch.mean(p * loss)

class pAUC_fai(nn.Module):
    """Score-weighted pairwise surrogate for one-way partial AUC.

    Each positive/negative pair loss is multiplied by a weight computed from
    the negative's score ``f_n``:

    * ``p_type='poly'``: ``(f_n + eps) ** (1 / (gamma - 1))``; this assumes
      ``f_n + eps >= 0`` (a fractional power of a negative number is NaN).
    * ``p_type='exp'``: ``1 - exp(-gamma * f_n)``.

    NOTE(review): unlike the moving-average losses in this module, the weight
    is not detached, so gradients also flow through it -- confirm intended.

    Args:
        threshold: margin for squared hinge loss.
        gamma: shape parameter of the weight function (``gamma != 1`` for 'poly').
        p_type: ``'poly'`` or ``'exp'``.
        loss_type: ``'sh'`` (squared hinge) or ``'sq'`` (square).
        eps: offset added to the negative scores in the 'poly' weight.
    """

    def __init__(self, threshold=1, gamma=100, p_type = 'poly', loss_type = 'sh', eps=1e-6):
        super(pAUC_fai, self).__init__()
        self.gamma = gamma
        self.p_type = p_type
        self.threshold = threshold
        self.eps = eps
        self.loss_type = loss_type
        print('The loss type is :', self.loss_type)

    def forward(self, y_pred, y_true):
        """Return the weighted pairwise loss of one mini-batch.

        Args:
            y_pred: prediction scores, one per sample.
            y_true: labels aligned with ``y_pred`` (1 = positive, 0 = negative).

        Raises:
            ValueError: for an unknown ``loss_type`` or ``p_type``.
        """
        f_ps = y_pred[y_true == 1].reshape(-1, 1)   # (n_pos, 1)
        f_ns = y_pred[y_true == 0].reshape(1, -1)   # (1, n_neg)
        mat_data = f_ns.repeat(len(f_ps), 1)        # (n_pos, n_neg)

        margin = self.threshold - (f_ps - mat_data)
        if self.loss_type == 'sh':
            loss = torch.max(margin, torch.zeros_like(mat_data)) ** 2
        elif self.loss_type == 'sq':
            loss = margin ** 2
        else:
            raise ValueError(f"Unsupported loss_type: {self.loss_type!r} (expected 'sh' or 'sq')")

        # Per-negative weight, shape (1, n_neg); broadcasts over the positives.
        if self.p_type == 'poly':
            p = torch.pow(f_ns+self.eps, 1/(self.gamma-1))
        elif self.p_type == 'exp':
            p = 1 - torch.exp(- self.gamma * f_ns)
        else:
            raise ValueError(f"Unsupported p_type: {self.p_type!r} (expected 'poly' or 'exp')")

        return torch.mean(p * loss)

class pAUC_mini(nn.Module):
    """Top-k pairwise surrogate for one-way partial AUC.

    Only the ``round(num_neg * gamma)`` highest-scored negatives of the batch
    are paired with every positive; the loss is the mean pair loss.

    Args:
        num_neg: used with ``gamma`` to choose how many negatives to keep
            (presumably the number of negatives per batch -- confirm with caller).
        threshold: margin for squared hinge loss.
        gamma: fraction of negatives kept; rounded to a multiple of ``1 / num_neg``.
        loss_type: ``'sh'`` (squared hinge) or ``'sq'`` (square).

    Raises:
        ValueError: if ``gamma`` rounds to zero, or (in ``forward``) for an
            unknown ``loss_type``.
    """

    def __init__(self, num_neg, threshold=1, gamma=0.2, loss_type = 'sh'):
        super(pAUC_mini, self).__init__()
        self.gamma = round(gamma*num_neg)/num_neg
        if self.gamma <= 0:
            # Keeping zero negatives would make forward average an empty tensor (NaN).
            raise ValueError(f'gamma={gamma} rounds to {self.gamma} with num_neg={num_neg}; it must be positive')
        self.num_neg = num_neg
        self.threshold = threshold
        self.loss_type = loss_type
        print('Num negative :', self.num_neg)
        print('The loss type is :', self.loss_type)

    def forward(self, y_pred, y_true):
        """Return the mean pair loss over all positives and the top-scored negatives.

        Args:
            y_pred: prediction scores, one per sample.
            y_true: labels aligned with ``y_pred`` (1 = positive, 0 = negative).
        """
        f_ps = y_pred[y_true == 1].view(-1, 1)      # (n_pos, 1)
        f_ns = y_pred[y_true == 0].view(-1)         # (n_neg,)

        # round(), not int(): num_neg * gamma is a float that can land just below
        # the intended integer (49 * (1 / 49) == 0.9999999999999999), and
        # truncating it would drop a negative -- or all of them.
        num_top = round(self.num_neg * self.gamma)
        partial_arg = torch.topk(f_ns, num_top, sorted=False)[1]
        mat_data = f_ns[partial_arg].repeat(len(f_ps), 1)   # (n_pos, num_top)

        margin = self.threshold - (f_ps - mat_data)
        if self.loss_type == 'sh':
            loss = torch.max(margin, torch.zeros_like(mat_data)) ** 2
        elif self.loss_type == 'sq':
            loss = margin ** 2
        else:
            raise ValueError(f"Unsupported loss_type: {self.loss_type!r} (expected 'sh' or 'sq')")

        return torch.mean(loss)


class SOAPLOSS(nn.Module):
    """SOAP-style surrogate loss for average precision.

    Every positive of the batch is paired with all samples of the batch
    (positives first). Two per-positive moving averages of the pair losses
    are kept: ``u_pos`` over positive partners (negative partners masked to 0)
    and ``u_all`` over all partners. Pair losses are weighted by
    ``(u_pos - u_all) / u_all**2`` for positive partners and
    ``u_pos / u_all**2`` for negative partners.

    Args:
        data_length: size of the per-sample state buffers; every ``index_p``
            value must be smaller than it.
        threshold: margin for squared hinge loss.
        gamma: moving-average weight for ``u_pos`` / ``u_all``.
        loss_type: ``'sqh'`` (squared hinge), ``'lgs'`` or ``'sgm'``. The latter
            two call module-level ``logistic_loss`` / ``sigmoid_loss``, which
            must be in scope (their import is commented out in this file).
    """

    def __init__(self, data_length, threshold=1, gamma=0.9, loss_type = 'sqh'):
        super(SOAPLOSS, self).__init__()
        # Keep state on the GPU when one is available, otherwise on the CPU.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.u_all = torch.zeros(data_length, 1, device=device)
        self.u_pos = torch.zeros(data_length, 1, device=device)
        self.threshold = threshold
        self.loss_type = loss_type
        self.gamma = gamma
        print('The loss type is :', self.loss_type)

    def update_smoothing(self, decay_factor):
        """Divide the moving-average weight ``gamma`` by ``decay_factor``."""
        self.gamma = self.gamma/decay_factor

    def forward(self, y_pred, y_true, index_p, index_n):
        """Return the SOAP loss of one mini-batch and update ``u_pos`` / ``u_all``.

        Args:
            y_pred: prediction scores, one per sample.
            y_true: labels aligned with ``y_pred`` (1 = positive, 0 = negative).
            index_p: indices into the state buffers for the batch's positives.
            index_n: unused; kept for interface compatibility.

        Raises:
            ValueError: for an unknown ``loss_type``.
        """
        f_ps = y_pred[y_true == 1].view(-1)
        f_ns = y_pred[y_true == 0].view(-1)
        num_pos = f_ps.size(0)

        # Row i pairs positive i with every sample: positives first, then negatives.
        vec_dat = torch.cat((f_ps, f_ns), 0)
        mat_data = vec_dat.repeat(num_pos, 1)       # (n_pos, n_pos + n_neg)
        f_ps = f_ps.view(-1, 1)

        neg_mask = torch.ones_like(mat_data)
        neg_mask[:, 0:num_pos] = 0
        pos_mask = torch.zeros_like(mat_data)
        pos_mask[:, 0:num_pos] = 1

        if self.loss_type == 'sqh':
            surrogate = torch.max(self.threshold - (f_ps - mat_data), torch.zeros_like(mat_data)) ** 2
        elif self.loss_type == 'lgs':
            surrogate = logistic_loss(f_ps, mat_data, self.threshold)
        elif self.loss_type == 'sgm':
            surrogate = sigmoid_loss(f_ps, mat_data, self.threshold)
        else:
            raise ValueError(f"Unsupported loss_type: {self.loss_type!r} (expected 'sqh', 'lgs' or 'sgm')")

        neg_loss = surrogate * neg_mask
        pos_loss = surrogate * pos_mask
        loss = pos_loss + neg_loss

        # The moving averages are optimizer state: update them from detached
        # values so autograd history does not chain across iterations.
        pos_loss_const = pos_loss.detach()
        loss_const = loss.detach()
        if num_pos == 1:
            self.u_pos[index_p] = (1 - self.gamma) * self.u_pos[index_p] + self.gamma * pos_loss_const.mean()
            self.u_all[index_p] = (1 - self.gamma) * self.u_all[index_p] + self.gamma * loss_const.mean()
        else:
            self.u_all[index_p] = (1 - self.gamma) * self.u_all[index_p] + self.gamma * loss_const.mean(1, keepdim=True)
            self.u_pos[index_p] = (1 - self.gamma) * self.u_pos[index_p] + self.gamma * pos_loss_const.mean(1, keepdim=True)

        # Constant pair weights; gradients flow only through `loss`.
        p = (self.u_pos[index_p] - (self.u_all[index_p]) * pos_mask) / (self.u_all[index_p] ** 2)

        return torch.mean(p * loss)


class P_PUSH(nn.Module):
    """p-norm push surrogate over the positives.

    For each positive, a moving average ``u_pos`` of its mean pair loss
    against the batch's negatives is kept. Its pair losses are weighted by
    ``poly * u_pos ** (poly - 1)``, the derivative of ``u ** poly``.

    Args:
        pos_length: size of the per-positive state ``u_pos``; every
            ``index_p`` value must be smaller than it.
        threshold: margin for squared hinge loss.
        gamma: moving-average weight for ``u_pos``.
        poly: exponent of the push.
        loss_type: ``'sh'`` (squared hinge) or ``'sq'`` (square).
    """

    def __init__(self, pos_length, threshold=1, gamma=0.9, poly=2, loss_type = 'sh'):
        super(P_PUSH, self).__init__()
        self.poly = poly
        self.gamma = gamma
        # Keep state on the GPU when one is available, otherwise on the CPU.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.u_pos = torch.zeros(pos_length, 1, device=device)
        self.threshold = threshold
        self.loss_type = loss_type
        print('The loss type is :', self.loss_type)

    def update_smoothing(self, decay_factor):
        """Divide the moving-average weight ``gamma`` by ``decay_factor``."""
        self.gamma = self.gamma/decay_factor

    def forward(self, y_pred, y_true, index_p, index_n):
        """Return the p-norm push loss of one mini-batch and update ``u_pos``.

        Args:
            y_pred: prediction scores, one per sample.
            y_true: labels aligned with ``y_pred`` (1 = positive, 0 = negative).
            index_p: indices into ``u_pos`` for the batch's positives.
            index_n: unused; kept for interface compatibility.

        Raises:
            ValueError: for an unknown ``loss_type``.
        """
        f_ps = y_pred[y_true == 1].reshape(-1, 1)   # (n_pos, 1)
        f_ns = y_pred[y_true == 0].reshape(1, -1)   # (1, n_neg)
        mat_data = f_ns.repeat(len(f_ps), 1)        # (n_pos, n_neg)

        margin = self.threshold - (f_ps - mat_data)
        if self.loss_type == 'sh':
            loss = torch.max(margin, torch.zeros_like(mat_data)) ** 2
        elif self.loss_type == 'sq':
            loss = margin ** 2
        else:
            raise ValueError(f"Unsupported loss_type: {self.loss_type!r} (expected 'sh' or 'sq')")

        # Update the state from a detached copy so autograd history does not
        # chain across iterations through u_pos.
        loss_const = loss.detach()
        if f_ps.size(0) == 1:
            self.u_pos[index_p] = (1 - self.gamma) * self.u_pos[index_p] + self.gamma * loss_const.mean()
        else:
            self.u_pos[index_p] = (1 - self.gamma) * self.u_pos[index_p] + self.gamma * loss_const.mean(1, keepdim=True)

        # Constant weights; gradients flow only through `loss`.
        p = self.poly*(self.u_pos[index_p]**(self.poly-1))

        return torch.mean(p * loss)

class AUC_pair(nn.Module):
    """Plain pairwise AUC surrogate: the mean loss over all positive/negative pairs.

    Args:
        threshold: margin for squared hinge loss.
        loss_type: ``'sh'`` (squared hinge) or ``'sq'`` (square).
    """

    def __init__(self, threshold=1, loss_type = 'sh'):
        super(AUC_pair, self).__init__()
        self.threshold = threshold
        self.loss_type = loss_type
        print('The loss type is :', self.loss_type)

    def forward(self, y_pred, y_true):
        """Return the mean pair loss of one mini-batch.

        Args:
            y_pred: prediction scores, one per sample.
            y_true: labels aligned with ``y_pred`` (1 = positive, 0 = negative).

        Raises:
            ValueError: for an unknown ``loss_type``.
        """
        f_ps = y_pred[y_true == 1].reshape(-1, 1)   # (n_pos, 1)
        f_ns = y_pred[y_true == 0].reshape(1, -1)   # (1, n_neg)
        mat_data = f_ns.repeat(len(f_ps), 1)        # (n_pos, n_neg)

        margin = self.threshold - (f_ps - mat_data)
        if self.loss_type == 'sh':
            loss = torch.max(margin, torch.zeros_like(mat_data)) ** 2
        elif self.loss_type == 'sq':
            loss = margin ** 2
        else:
            raise ValueError(f"Unsupported loss_type: {self.loss_type!r} (expected 'sh' or 'sq')")

        return torch.mean(loss)



#############   two way pAUC    #######

class pAUC_KL_two(nn.Module):
    """KL-weighted pairwise surrogate for two-way partial AUC.

    Pair losses ``l`` of each positive against the batch negatives are turned
    into ``exp(l / Lambda)``. A per-positive moving average ``u_pos`` of that
    quantity and a scalar moving average ``w`` of ``u_pos ** (Lambda / tau)``
    are kept. The (constant) pair weight is
    ``u_pos ** (Lambda / tau - 1) * exp(l / Lambda) / w``.

    Args:
        pos_length: size of the per-positive state ``u_pos``; every
            ``index_p`` value must be smaller than it.
        Lambda: temperature over the negatives.
        tau: temperature over the positives.
        threshold: margin for squared hinge loss.
        beta_1: moving-average weight for ``u_pos``.
        beta_2: moving-average weight for ``w``.
        loss_type: only ``'sqh'`` (squared hinge) is supported.
    """

    def __init__(self, pos_length, Lambda, tau, threshold=1.0, beta_1=0.9, beta_2=0.9, loss_type = 'sqh'):
        super(pAUC_KL_two, self).__init__()
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.Lambda = Lambda
        self.tau = tau
        # Keep state on the GPU when one is available, otherwise on the CPU.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.u_pos = torch.zeros(pos_length, 1, device=device)
        self.w = 0.0
        self.threshold = threshold
        self.loss_type = loss_type
        print('The loss type is :', self.loss_type)

    def update_smoothing(self, decay_factor):
        """Divide both moving-average weights by ``decay_factor``."""
        self.beta_1 = self.beta_1/decay_factor
        self.beta_2 = self.beta_2/decay_factor

    def forward(self, y_pred, y_true, index_p, index_n):
        """Return the two-way KL pAUC loss of one mini-batch and update ``u_pos`` / ``w``.

        Args:
            y_pred: prediction scores, one per sample.
            y_true: labels aligned with ``y_pred`` (1 = positive, 0 = negative).
            index_p: indices into ``u_pos`` for the batch's positives.
            index_n: unused; kept for interface compatibility.

        Raises:
            ValueError: for an unknown ``loss_type``.
        """
        f_ps = y_pred[y_true == 1].reshape(-1, 1)   # (n_pos, 1)
        f_ns = y_pred[y_true == 0].reshape(1, -1)   # (1, n_neg)
        mat_data = f_ns.repeat(len(f_ps), 1)        # (n_pos, n_neg)

        if self.loss_type == 'sqh':
            loss = torch.max(self.threshold - (f_ps - mat_data), torch.zeros_like(mat_data)) ** 2
        else:
            raise ValueError(f"Unsupported loss_type: {self.loss_type!r} (expected 'sqh')")

        exp_loss = torch.exp(loss/self.Lambda)
        # u_pos and w are optimizer state: update them from detached values so
        # autograd history does not chain across iterations.
        exp_loss_const = exp_loss.detach()

        if f_ps.size(0) == 1:
            self.u_pos[index_p] = (1 - self.beta_1) * self.u_pos[index_p] + self.beta_1 * exp_loss_const.mean()
        else:
            self.u_pos[index_p] = (1 - self.beta_1) * self.u_pos[index_p] + self.beta_1 * exp_loss_const.mean(1, keepdim=True)

        self.w = (1 - self.beta_2) * self.w + self.beta_2 * (torch.pow(self.u_pos[index_p], self.Lambda/self.tau).mean())

        # Constant pair weights; gradients flow only through `loss`.
        p = torch.pow(self.u_pos[index_p], self.Lambda/self.tau - 1) * exp_loss_const / self.w

        return torch.mean(p * loss)

class pAUC_fai_two(nn.Module):
    """Score-weighted pairwise surrogate for two-way partial AUC.

    The weight of pair (i, j) is ``col[i] * row[j]``. Here ``col`` is computed
    from ``1 - f_p`` of the positive and ``row`` from ``f_n`` of the negative:

    * ``p_type='poly'``: ``(x + eps) ** (1 / (gamma - 1))``; assumes
      ``x + eps >= 0`` (a fractional power of a negative number is NaN).
    * ``p_type='exp'``: ``1 - exp(-gamma * x)``.

    NOTE(review): the weights are not detached, so gradients also flow through
    them -- confirm intended.

    Args:
        threshold: margin for squared hinge loss.
        gamma: shape parameter of the weight function (``gamma != 1`` for 'poly').
        p_type: ``'poly'`` or ``'exp'``.
        loss_type: ``'sh'`` (squared hinge) or ``'sq'`` (square).
        eps: offset added inside the 'poly' weight.
    """

    def __init__(self, threshold=1.0, gamma=100.0, p_type = 'poly', loss_type = 'sh', eps=1e-6):
        super(pAUC_fai_two, self).__init__()
        self.gamma = gamma
        self.p_type = p_type
        self.threshold = threshold
        self.eps = eps
        self.loss_type = loss_type
        print('The loss type is :', self.loss_type)

    def forward(self, y_pred, y_true):
        """Return the weighted pairwise loss of one mini-batch.

        Args:
            y_pred: prediction scores, one per sample.
            y_true: labels aligned with ``y_pred`` (1 = positive, 0 = negative).

        Raises:
            ValueError: for an unknown ``loss_type`` or ``p_type``.
        """
        f_ps = y_pred[y_true == 1].reshape(-1, 1)   # (n_pos, 1)
        f_ns = y_pred[y_true == 0].reshape(1, -1)   # (1, n_neg)
        mat_data = f_ns.repeat(len(f_ps), 1)        # (n_pos, n_neg)

        margin = self.threshold - (f_ps - mat_data)
        if self.loss_type == 'sh':
            loss = torch.max(margin, torch.zeros_like(mat_data)) ** 2
        elif self.loss_type == 'sq':
            loss = margin ** 2
        else:
            raise ValueError(f"Unsupported loss_type: {self.loss_type!r} (expected 'sh' or 'sq')")

        if self.p_type == 'poly':
            col = torch.pow(1-f_ps+self.eps, 1/(self.gamma-1))    # (n_pos, 1)
            row = torch.pow(f_ns+self.eps, 1/(self.gamma-1))      # (1, n_neg)
        elif self.p_type == 'exp':
            col = 1 - torch.exp(- self.gamma * (1-f_ps))
            row = 1 - torch.exp(- self.gamma * f_ns)
        else:
            raise ValueError(f"Unsupported p_type: {self.p_type!r} (expected 'poly' or 'exp')")

        # Outer product -> (n_pos, n_neg) pair weights (same as col * row).
        p = torch.mm(col, row)

        return torch.mean(p * loss)

class pAUC_mini_two(nn.Module):
    """Top-k pairwise surrogate for two-way partial AUC.

    Only the ``round(gamma * num_pos)`` lowest-scored positives and the
    ``round(gamma * num_neg)`` highest-scored negatives of the batch are
    paired; the loss is the mean pair loss.

    Args:
        num_pos: used with ``gamma`` to choose how many positives to keep
            (presumably the positives per batch -- confirm with caller).
        num_neg: used with ``gamma`` to choose how many negatives to keep.
        threshold: margin for squared hinge loss.
        gamma: fraction of positives and of negatives kept.
        loss_type: only ``'sqh'`` (squared hinge) is supported.
    """

    def __init__(self, num_pos, num_neg, threshold=1.0, gamma=0.2, loss_type = 'sqh'):
        super(pAUC_mini_two, self).__init__()
        self.gamma = gamma
        self.num_pos = num_pos
        self.num_neg = num_neg
        self.threshold = threshold
        self.loss_type = loss_type
        print('The loss type is :', self.loss_type)

    def forward(self, y_pred, y_true):
        """Return the mean pair loss over the hardest positives and negatives.

        Args:
            y_pred: prediction scores, one per sample.
            y_true: labels aligned with ``y_pred`` (1 = positive, 0 = negative).

        Raises:
            ValueError: for an unknown ``loss_type``.
        """
        f_ps = y_pred[y_true == 1].view(-1)
        f_ns = y_pred[y_true == 0].view(-1)

        # Hardest positives score lowest; hardest negatives score highest.
        partial_arg_pos = torch.topk(f_ps, round(self.gamma*self.num_pos), largest=False, sorted=False)[1]
        partial_arg_neg = torch.topk(f_ns, round(self.gamma*self.num_neg), largest=True, sorted=False)[1]

        mat_data = f_ns[partial_arg_neg].repeat(len(partial_arg_pos), 1)   # (k_pos, k_neg)
        selected_ps = f_ps[partial_arg_pos].view(-1, 1)                     # (k_pos, 1)

        if self.loss_type == 'sqh':
            loss = torch.max(self.threshold - (selected_ps - mat_data), torch.zeros_like(mat_data)) ** 2
        else:
            raise ValueError(f"Unsupported loss_type: {self.loss_type!r} (expected 'sqh')")

        return torch.mean(loss)

class PUSH2P2LOSS(nn.Module):
    """Power-2 push surrogate over the negatives.

    For each negative, a moving average ``u_neg`` of its mean squared-hinge
    pair loss against the given positives is kept. Its pair losses are
    weighted by ``2 * u_neg``, the derivative of ``u ** 2``.

    Args:
        threshold: margin for squared hinge loss.
        batch_size: unused; kept for interface compatibility.
        data_length: size of the per-sample state buffers; every ``index_n``
            value must be smaller than it.
        loss_type: only ``'sqh'`` (squared hinge) is supported.
    """

    def __init__(self, threshold, batch_size, data_length, loss_type = 'sqh'):
        super(PUSH2P2LOSS, self).__init__()
        # Keep state on the GPU when one is available, otherwise on the CPU.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.u_neg = torch.zeros(data_length, 1, device=device)
        self.u_pos = torch.zeros(data_length, 1, device=device)   # not used by forward
        self.threshold = threshold
        self.loss_type = loss_type
        print('The loss type is :', self.loss_type)

    def forward(self, f_ps, f_ns, index_p, index_n, gamma):
        """Return the power-2 push loss and update ``u_neg``.

        Args:
            f_ps: scores of the positive samples.
            f_ns: scores of the negative samples.
            index_p: only printed when the loss is NaN.
            index_n: indices into ``u_neg`` for the given negatives.
            gamma: moving-average weight for ``u_neg``.

        Raises:
            ValueError: for an unknown ``loss_type``.
        """
        f_ps = f_ps.view(-1)
        f_ns = f_ns.view(-1, 1)                     # (n_neg, 1)
        mat_data = f_ps.repeat(len(f_ns), 1)        # (n_neg, n_pos)

        if self.loss_type == 'sqh':
            loss = torch.max(self.threshold - (mat_data - f_ns), torch.zeros_like(mat_data)) ** 2
        else:
            raise ValueError(f"Unsupported loss_type: {self.loss_type!r} (expected 'sqh')")

        # Update the state from a detached copy so autograd history does not
        # chain across iterations through u_neg.
        loss_const = loss.detach()
        if f_ns.size(0) == 1:
            self.u_neg[index_n] = (1 - gamma) * self.u_neg[index_n] + gamma * loss_const.mean()
        else:
            self.u_neg[index_n] = (1 - gamma) * self.u_neg[index_n] + gamma * loss_const.mean(1, keepdim=True)

        # Constant weights; gradients flow only through `loss`.
        p = 2*self.u_neg[index_n]
        loss = torch.mean(p * loss)

        if torch.isnan(loss):
            print('####',index_p,index_n)

        return loss


class _TwoSidedPushLoss(nn.Module):
    """Shared implementation of the two-sided power push losses.

    Subclasses set ``_power``. For every positive (and every negative) a
    moving average of its mean squared-hinge pair loss is kept in ``u_pos``
    (``u_neg``). Each pair loss is weighted by ``_power * u ** (_power - 1)``,
    the derivative of ``u ** _power``. The positive-side and negative-side
    terms are averaged with equal weight.

    Args:
        threshold: margin for squared hinge loss.
        batch_size: unused; kept for interface compatibility.
        data_length: size of the per-sample state buffers; every ``index_p``
            / ``index_n`` value must be smaller than it.
        loss_type: only ``'sqh'`` (squared hinge) is supported.
    """

    def __init__(self, threshold, batch_size, data_length, loss_type = 'sqh'):
        super(_TwoSidedPushLoss, self).__init__()
        # Keep state on the GPU when one is available, otherwise on the CPU.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.u_neg = torch.zeros(data_length, 1, device=device)
        self.u_pos = torch.zeros(data_length, 1, device=device)
        self.threshold = threshold
        self.loss_type = loss_type
        print('The loss type is :', self.loss_type)

    def _weighted_mean(self, state, index, pair_loss, gamma):
        """Update ``state[index]`` with the row means of ``pair_loss``; return the power-weighted mean."""
        # Detached so autograd history does not chain across iterations via the state.
        pair_loss_const = pair_loss.detach()
        if pair_loss.size(0) == 1:
            state[index] = (1 - gamma) * state[index] + gamma * pair_loss_const.mean()
        else:
            state[index] = (1 - gamma) * state[index] + gamma * pair_loss_const.mean(1, keepdim=True)
        weight = self._power * state[index] ** (self._power - 1)
        return torch.mean(weight * pair_loss)

    def forward(self, f_ps, f_ns, index_p, index_n, gamma):
        """Return the two-sided push loss and update ``u_pos`` / ``u_neg``.

        Args:
            f_ps: scores of the positive samples.
            f_ns: scores of the negative samples.
            index_p: indices into ``u_pos`` for the given positives.
            index_n: indices into ``u_neg`` for the given negatives.
            gamma: moving-average weight for both state buffers.

        Raises:
            ValueError: for an unknown ``loss_type``.
        """
        f_ps = f_ps.view(-1)
        f_ns = f_ns.view(-1)
        mat_data_p = f_ps.repeat(len(f_ns), 1)      # (n_neg, n_pos)
        mat_data_n = f_ns.repeat(len(f_ps), 1)      # (n_pos, n_neg)
        f_ns = f_ns.view(-1, 1)
        f_ps = f_ps.view(-1, 1)

        if self.loss_type == 'sqh':
            # Same pairs seen from each side: one row per negative / per positive.
            pos_loss = torch.max(self.threshold - (mat_data_p - f_ns), torch.zeros_like(mat_data_p)) ** 2
            neg_loss = torch.max(self.threshold - (f_ps - mat_data_n), torch.zeros_like(mat_data_n)) ** 2
        else:
            raise ValueError(f"Unsupported loss_type: {self.loss_type!r} (expected 'sqh')")

        neg_term = self._weighted_mean(self.u_pos, index_p, neg_loss, gamma)   # per-positive push
        pos_term = self._weighted_mean(self.u_neg, index_n, pos_loss, gamma)   # per-negative push

        return 0.5*neg_term + 0.5*pos_term


class PUSH3P2LOSS(_TwoSidedPushLoss):
    """Two-sided push loss with power 2 (pair weights ``2 * u``)."""

    _power = 2


class PUSH3P3LOSS(_TwoSidedPushLoss):
    """Two-sided push loss with power 3 (pair weights ``3 * u ** 2``)."""

    _power = 3

